{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02152.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02152.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFYokuXDWzD3",
        "colab_type": "text"
      },
      "source": [
        "## 6.5 循环神经网络的简洁实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RSwabfDgXCbS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import math \n",
        "import numpy as np \n",
        "import torch\n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "import time \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "87sb195LkBcY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!mkdir ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "91eLoC67kN1R",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "a3b4eeb1-9fb5-4730-a2cc-d83c78aaff3f"
      },
      "source": [
        "!git clone https://github.com/ShusenTang/Dive-into-DL-PyTorch.git"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'Dive-into-DL-PyTorch'...\n",
            "remote: Enumerating objects: 1692, done.\u001b[K\n",
            "remote: Total 1692 (delta 0), reused 0 (delta 0), pack-reused 1692\u001b[K\n",
            "Receiving objects: 100% (1692/1692), 25.29 MiB | 33.54 MiB/s, done.\n",
            "Resolving deltas: 100% (975/975), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YN9QWrSpkVD-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!cp Dive-into-DL-PyTorch/data/jaychou_lyrics.txt.zip ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "e6bsKEbdkfAu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "548iqsW8k0im",
        "colab_type": "text"
      },
      "source": [
        "### 6.5.1 定义模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "op3H9Pthk53e",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_hiddens = 256 \n",
        "rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "biicaMAPlEjY",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9afb12ad-5cb8-4773-dcf5-21850e850c5b"
      },
      "source": [
        "num_steps = 35 \n",
        "batch_size = 2 \n",
        "state = None \n",
        "X = torch.rand(num_steps, batch_size, vocab_size)\n",
        "Y, state_new = rnn_layer(X, state)\n",
        "Y.shape, len(state_new), state_new[0].shape"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([35, 2, 256]), 1, torch.Size([2, 256]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VAJUCruJlrdN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class RNNModel(nn.Module):\n",
        "    def __init__(self, rnn_layer, vocab_size):\n",
        "        super(RNNModel, self).__init__()\n",
        "        self.rnn = rnn_layer \n",
        "        self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1)\n",
        "        self.vocab_size = vocab_size \n",
        "        self.dense = nn.Linear(self.hidden_size, vocab_size)\n",
        "        self.state = None \n",
        "\n",
        "    def forward(self, inputs, state):\n",
        "        X = d2l.to_onehot(inputs, self.vocab_size)\n",
        "        Y, self.state = self.rnn(torch.stack(X), state)\n",
        "        output = self.dense(Y.view(-1, Y.shape[-1]))\n",
        "        return output, self.state"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "es99Lln-mtUd",
        "colab_type": "text"
      },
      "source": [
        "### 6.5.2 训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4jzKmtl3mzHY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char, char_to_idx):\n",
        "    state = None \n",
        "    output = [char_to_idx[prefix[0]]]\n",
        "    for t in range(num_chars + len(prefix) - 1):\n",
        "        X = torch.tensor([output[-1]], device=device).view(1, 1)\n",
        "        if state is not None:\n",
        "            if isinstance(state, tuple):\n",
        "                state = (state[0].to(device), state[1].to(device))\n",
        "            else:\n",
        "                state = state.to(device)\n",
        "\n",
        "        (Y, state) = model(X, state)\n",
        "        if t < len(prefix) - 1:\n",
        "            output.append(char_to_idx[prefix[t + 1]])\n",
        "        else:\n",
        "            output.append(int(Y.argmax(dim=1).item()))\n",
        "    return ''.join([idx_to_char[i] for i in output])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SVorY1pjoOkZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "df345e57-cca8-4775-905b-17ab384ff4b6"
      },
      "source": [
        "model = RNNModel(rnn_layer, vocab_size).to(device)\n",
        "predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'分开天辛天天如蝶封如心前'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0iN-JeTuofcn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, \n",
        "                num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes):\n",
        "    loss = nn.CrossEntropyLoss()\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    model.to(device)\n",
        "    state = None \n",
        "    for epoch in range(num_epochs):\n",
        "        l_sum, n, start = 0.0, 0, time.time()\n",
        "        data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device)\n",
        "        for X, Y in data_iter:\n",
        "            if state is not None:\n",
        "                if isinstance (state, tuple):\n",
        "                    state = (state[0].detach(), state[1].detach())\n",
        "                else:\n",
        "                    state = state.detach()\n",
        "            (output, state) = model(X, state)\n",
        "            y = torch.transpose(Y, 0, 1).contiguous().view(-1)\n",
        "            l = loss(output, y.long())\n",
        "            optimizer.zero_grad()\n",
        "            l.backward()\n",
        "            d2l.grad_clipping(model.parameters(), clipping_theta, device)\n",
        "            optimizer.step()\n",
        "            l_sum += l.item() * y.shape[0]\n",
        "            n += y.shape[0]\n",
        "            \n",
        "        try:\n",
        "            perplexity = math.exp(l_sum / n)\n",
        "        except OverflowError:\n",
        "            perplexity = float('inf')\n",
        "        if (epoch + 1) % pred_period == 0:\n",
        "            print('epoch %d, perplexity %f, time %.2f sec' \n",
        "                % (epoch + 1, perplexity, time.time() - start))\n",
        "            for prefix in prefixes:\n",
        "                print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, \n",
        "                            idx_to_char, char_to_idx))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "96jUgxdVsFL7",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "d1abc62b-2cb3-4146-cc7e-aad92a47d30c"
      },
      "source": [
        "num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 \n",
        "pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']\n",
        "train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, \n",
        "            num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 50, perplexity 10.818418, time 0.14 sec\n",
            " - 分开始我不  想 我不能再想 我不能再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我\n",
            " - 不分开 我不能再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我\n",
            "epoch 100, perplexity 1.253825, time 0.14 sec\n",
            " - 分开始我不错搞错 拜托 我想是你的脑袋有问题 随便说说 其实我早已经猜透看透不想多说 只是我怕眼泪撑不住\n",
            " - 不分开不能承想你 是你是你 别不了 想要再这样打我妈妈 难道你手不会痛吗 其实我回家就想要阻止一切 让家庭\n",
            "epoch 150, perplexity 1.063882, time 0.14 sec\n",
            " - 分开始我不多痛熬 心伤透看到  什么都会有快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 如果我有轻功 \n",
            " - 不分开不了太多就我 就和你 让它一定的你 原时日记 你的手知后口让我知道 就是开不了口让她知道 就是开不了\n",
            "epoch 200, perplexity 1.031009, time 0.14 sec\n",
            " - 分开 我不多痛熬  我球你爸 我打我妈妈 我说你爸你 打我妈 这样 吗干嘛这样 何必让酒牵鼻子走 瞎 说\n",
            " - 不分开不了太多就我 无和汉常 我想就你已经我的我都会你烦爱  这样没担忧伤 难多 我跟一直到口睡著一口被老\n",
            "epoch 250, perplexity 1.019579, time 0.14 sec\n",
            " - 分开 我不懂 你的黑色幽默 想通 却又再考倒我 说散 你想很久了吧? 我不想拆穿你 当作 是你开的玩笑 \n",
            " - 不分开不了太多就我 一定一 从小到大你在 有的路从 时间变你 心老 看不著去想这样没担忧 唱着歌 一直走 \n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}